# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data for self attention UT."""

import numpy as np

# Sizes used by this self-attention UT data module; together they form the
# (SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE) shape of the GPU_DATA["random"] array below.
SEQ_LEN = 2
BATCH_SIZE = 2
HIDDEN_SIZE = 32


def get_init_params(seq_len, batch_size, hidden_size, kv_hidden_size):
    """Generate deterministic initial parameters for SelfAttention.

    The global NumPy RNG is re-seeded with 42 on every call, so calling this
    function twice with the same sizes returns identical arrays. Note that this
    also resets the global ``np.random`` state for the caller.

    Args:
        seq_len: First dimension of ``hidden_states`` and side length of the
            square ``attention_mask``.
        batch_size: Second dimension of ``hidden_states`` and leading dimension
            of ``attention_mask``.
        hidden_size: Last dimension of ``hidden_states``; also sizes the
            projection weight/bias and the q layernorm parameters.
        kv_hidden_size: Sizes the k layernorm parameters and contributes
            ``2 * kv_hidden_size`` rows to the fused qkv weight/bias.

    Returns:
        dict mapping parameter names to NumPy arrays. Every random array is
        float32 standard-normal noise scaled by 0.01; ``attention_mask`` is an
        int8 lower-triangular matrix of ones with shape
        ``(batch_size, 1, seq_len, seq_len)``.
    """
    np.random.seed(42)

    def _scaled_randn(*shape):
        # Same expression as each original entry: cast the draw to float32
        # first, then scale by 0.01.
        return 0.01 * np.random.randn(*shape).astype(np.float32)

    qkv_out_size = hidden_size + 2 * kv_hidden_size
    causal_ones = np.ones(shape=(batch_size, 1, seq_len, seq_len))

    # The order of the random draws below determines the generated values;
    # keep it unchanged so existing reference data stays valid.
    return {
        "hidden_states": _scaled_randn(seq_len, batch_size, hidden_size),
        "attention_mask": np.tril(causal_ones).astype(np.int8),
        "weight_qkv": _scaled_randn(qkv_out_size, hidden_size),
        "bias_qkv": _scaled_randn(qkv_out_size),
        "weight_proj": _scaled_randn(hidden_size, hidden_size),
        "bias_proj": _scaled_randn(hidden_size),
        "q_layernorm_weight": _scaled_randn(hidden_size),
        "q_layernorm_bias": _scaled_randn(hidden_size),
        "k_layernorm_weight": _scaled_randn(kv_hidden_size),
        "k_layernorm_bias": _scaled_randn(kv_hidden_size),
    }


# Placeholder: these values need to be actually computed or obtained from an
# external source.
#
# Reference ("golden") arrays for the self-attention UT. No dtype is passed to
# np.array, so every entry is stored as float64.
#   * "output_query_group_<N>": nested literal of shape (2, 2, 32).
#   * "bias_query_group_<N>":   flat literal of 32 values.
# NOTE(review): the "_4" / "_8" suffix presumably refers to the number of query
# groups used when the reference was produced -- confirm against the test that
# consumes this data.
GOLDEN_DATA = {
    "output_query_group_4": np.array(
        [[[-4.21647739e-04, 2.67271680e-04, -1.57212024e-04, 1.66020283e-04,
           -7.91710918e-05, 3.15163139e-04, -2.50073237e-04, 1.53204164e-04,
           3.76587966e-04, -9.23072861e-04, 4.28767555e-04, -3.34043987e-04,
           -6.14050732e-05, -1.64010387e-04, 1.33967347e-04, 4.13777219e-04,
           3.82365310e-04, -2.79074447e-04, -2.42814960e-04, -1.98864873e-05,
           1.66379017e-04, 2.40207330e-04, -5.62029076e-04, 6.32720301e-04,
           -3.30142648e-05, -3.02388100e-04, -3.94221745e-04, 3.36510333e-04,
           -8.09338671e-05, 2.49329518e-04, 5.58624743e-04, 6.31148869e-04],
          [-4.73979657e-04, 3.29846342e-04, -1.95527915e-04, 1.13577087e-04,
           -7.96524255e-05, 3.45273584e-04, -2.71982601e-04, 1.42623961e-04,
           3.96650343e-04, -9.76846903e-04, 4.77186579e-04, -3.48150148e-04,
           -1.38520758e-04, -1.54548747e-04, 1.59314033e-04, 4.08631517e-04,
           3.83883016e-04, -1.95456611e-04, -2.47606658e-04, -3.95881689e-05,
           1.86816658e-04, 2.41959817e-04, -6.53814001e-04, 5.85544622e-04,
           2.14241045e-05, -3.43817286e-04, -3.97544558e-04, 3.64110136e-04,
           -8.69510550e-05, 3.20769468e-04, 5.54460625e-04, 5.64362388e-04]],

         [[-4.32563014e-04, 2.90979835e-04, -1.52971057e-04, 1.60375319e-04,
           -7.36243819e-05, 3.44681321e-04, -2.72753503e-04, 1.48297011e-04,
           3.89526598e-04, -8.96966667e-04, 4.55134665e-04, -3.22765613e-04,
           -8.57864288e-05, -1.70633270e-04, 1.38077652e-04, 4.38710063e-04,
           3.86431988e-04, -2.47401214e-04, -2.34232211e-04, -3.21087282e-05,
           1.32673667e-04, 2.23902171e-04, -5.87977120e-04, 6.28675451e-04,
           -2.42550032e-05, -2.93377117e-04, -4.04996274e-04, 3.52784409e-04,
           -7.31592663e-05, 2.66829127e-04, 5.35687141e-04, 6.08674018e-04],
          [-4.71207080e-04, 3.10085510e-04, -1.83070850e-04, 1.10893605e-04,
           -9.89583205e-05, 3.62709688e-04, -2.55480612e-04, 1.48159641e-04,
           4.12498834e-04, -9.77750984e-04, 4.57207119e-04, -3.79265402e-04,
           -1.05268722e-04, -1.51554035e-04, 1.70723884e-04, 4.27430117e-04,
           3.94347764e-04, -2.20143600e-04, -2.50991725e-04, -4.93631997e-05,
           1.89635772e-04, 2.39642162e-04, -6.27847796e-04, 5.92195836e-04,
           -6.36474488e-06, -3.36142897e-04, -3.94580566e-04, 3.61132581e-04,
           -9.50240792e-05, 2.85849732e-04, 5.78690844e-04, 5.96937840e-04]]]),
    "output_query_group_8": np.array(
        [[[-2.2472338e-04, 4.4892236e-04, 5.9192424e-04, 3.2775174e-04,
           -1.9141105e-04, -7.1822474e-04, -2.4466100e-04, -5.0918222e-04,
           3.6290314e-04, -3.4275738e-04, -2.8467472e-04, -3.8527761e-04,
           8.2169188e-04, 3.1178875e-04, -1.0932739e-03, -1.3872179e-04,
           2.2414084e-04, 8.6626191e-05, -1.0176909e-03, -2.0129493e-04,
           -1.0703942e-04, 1.1569806e-03, 2.2872546e-04, -7.1597303e-04,
           3.2158772e-04, 4.2374039e-04, 1.4240046e-04, -3.3106090e-04,
           -1.7313687e-03, -3.6458942e-04, -4.2334365e-04, 9.3625725e-04],
          [-2.7430593e-04, 4.2216972e-04, 5.7084975e-04, 2.9147934e-04,
           -1.2742028e-04, -6.8236265e-04, -2.3248902e-04, -4.3632966e-04,
           2.4007766e-04, -3.1793397e-04, -2.6888086e-04, -4.2644801e-04,
           7.8148977e-04, 3.4356455e-04, -1.0211504e-03, -1.4180086e-04,
           2.3824123e-04, 1.7841172e-04, -9.9293527e-04, -1.9407798e-04,
           -1.8763087e-04, 1.0957000e-03, 2.1117380e-04, -6.4608007e-04,
           2.4210017e-04, 4.5302502e-04, 1.5192835e-04, -3.5992547e-04,
           -1.6004187e-03, -2.8465554e-04, -3.2311573e-04, 9.6602732e-04]],

         [[-2.4132521e-04, 4.5797296e-04, 5.5707787e-04, 3.3042984e-04,
           -1.8871749e-04, -7.3010207e-04, -2.5176996e-04, -5.2437640e-04,
           3.4348245e-04, -3.4017192e-04, -2.8676217e-04, -3.9425577e-04,
           8.3353330e-04, 3.0162121e-04, -1.1022381e-03, -1.5777114e-04,
           2.1856885e-04, 1.1532350e-04, -1.0146681e-03, -1.9876884e-04,
           -1.4150815e-04, 1.1425300e-03, 2.1436327e-04, -7.2148727e-04,
           3.0468346e-04, 4.3487156e-04, 1.2688893e-04, -3.4853906e-04,
           -1.7463891e-03, -3.2397313e-04, -4.1412702e-04, 9.5052837e-04],
          [-2.4978167e-04, 3.8561973e-04, 5.6573434e-04, 2.7176086e-04,
           -2.0329939e-04, -6.7764614e-04, -2.0067416e-04, -4.2722997e-04,
           2.8411823e-04, -3.3754398e-04, -2.7947137e-04, -3.9472463e-04,
           7.7866728e-04, 3.7953054e-04, -1.0514238e-03, -1.2808030e-04,
           2.1401311e-04, 1.4669952e-04, -9.8361552e-04, -1.9868695e-04,
           -1.5745701e-04, 1.1008645e-03, 1.9836229e-04, -6.5060926e-04,
           2.8522784e-04, 4.5993319e-04, 1.4265678e-04, -3.5586260e-04,
           -1.5962820e-03, -3.1004893e-04, -3.5172829e-04, 9.8165218e-04]]]),
    "bias_query_group_4": np.array(
        [0.00665417, 0.01941217, -0.00878323, -0.00378008, 0.00231446,
         0.0064649, -0.00215668, -0.00872961, 0.00881408, 0.00721135,
         -0.00916274, 0.01355443, 0.01170199, 0.00134296, 0.00079598,
         0.00554058, -0.00861703, 0.00030031, -0.02152384, 0.00876456,
         -0.01561493, 0.0150342, -0.0033012, -0.00211667, -0.00627734,
         -0.00288039, 0.01418531, -0.02487809, 0.01276965, 0.00338023,
         -0.01207022, -0.01075312]),
    "bias_query_group_8": np.array(
        [-0.00019874, -0.00107265, 0.01220283, -0.01076679, -0.00762519,
         0.01036303, -0.00595535, 0.00900533, 0.01699436, -0.00401795,
         0.00681362, -0.01153692, 0.00754733, 0.00199424, 0.00473038,
         -0.00645572, -0.01661083, 0.00503252, -0.02063403, -0.0031735,
         0.00941182, 0.00191898, 0.01948343, 0.01025109, -0.00653898,
         0.00027313, -0.00047371, -0.01582177, -0.00940399, -0.00560761,
         -0.02036221, 0.00047493]),
}

# Second set of reference arrays with the same keys and shapes as GOLDEN_DATA
# ((2, 2, 32) outputs and 32-element biases, stored as float64), plus a
# "random" entry.
# NOTE(review): judging by the name these are presumably results captured from
# a GPU run, and the values look rounded to a lower-precision format -- confirm
# the provenance and the comparison tolerance with the consuming test.
GPU_DATA = {
    # Drawn from the global NumPy RNG when this module is imported; no seed is
    # set at module level in the visible source before this line.
    # NOTE(review): the values therefore depend on the RNG state at import time
    # and importing this module advances that state -- confirm this is intended.
    "random": np.random.randn(SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE).astype(np.float32),
    "output_query_group_4": np.array(
        [[[-4.19616699e-04, 2.67028809e-04, -1.57356262e-04, 1.65939331e-04,
           -7.96318054e-05, 3.14712524e-04, -2.49862671e-04, 1.51634216e-04,
           3.77655029e-04, -9.23156738e-04, 4.29153442e-04, -3.31878662e-04,
           -6.15119934e-05, -1.64031982e-04, 1.35421753e-04, 4.11987305e-04,
           3.81469727e-04, -2.78472900e-04, -2.42233276e-04, -2.00271606e-05,
           1.66893005e-04, 2.39372253e-04, -5.60760498e-04, 6.33239746e-04,
           -3.40938568e-05, -3.03268433e-04, -3.92913818e-04, 3.37600708e-04,
           -8.10623169e-05, 2.49862671e-04, 5.56945801e-04, 6.29425049e-04],
          [-4.74929810e-04, 3.29971313e-04, -1.95503235e-04, 1.13010406e-04,
           -7.96318054e-05, 3.45230103e-04, -2.72750854e-04, 1.42097473e-04,
           3.96728516e-04, -9.76562500e-04, 4.78744507e-04, -3.47137451e-04,
           -1.39236450e-04, -1.54495239e-04, 1.61170959e-04, 4.08172607e-04,
           3.83377075e-04, -1.95503235e-04, -2.47955322e-04, -3.93390656e-05,
           1.87873840e-04, 2.42233276e-04, -6.56127930e-04, 5.83648682e-04,
           2.08616257e-05, -3.45230103e-04, -3.96728516e-04, 3.64303589e-04,
           -8.63075256e-05, 3.20434570e-04, 5.56945801e-04, 5.64575195e-04]],

         [[-4.31060791e-04, 2.88009644e-04, -1.51634216e-04, 1.59263611e-04,
           -7.48634338e-05, 3.43322754e-04, -2.70843506e-04, 1.47819519e-04,
           3.89099121e-04, -8.96453857e-04, 4.53948975e-04, -3.20434570e-04,
           -8.48770142e-05, -1.70707703e-04, 1.38282776e-04, 4.36782837e-04,
           3.85284424e-04, -2.47955322e-04, -2.33650208e-04, -3.29017639e-05,
           1.33514404e-04, 2.24113464e-04, -5.87463379e-04, 6.25610352e-04,
           -2.63452530e-05, -2.93731689e-04, -4.02450562e-04, 3.52859497e-04,
           -7.39097595e-05, 2.65121460e-04, 5.34057617e-04, 6.06536865e-04],
          [-4.73022461e-04, 3.08990479e-04, -1.83105469e-04, 1.10149384e-04,
           -1.00135803e-04, 3.64303589e-04, -2.55584717e-04, 1.46865845e-04,
           4.13894653e-04, -9.76562500e-04, 4.59671021e-04, -3.79562378e-04,
           -1.06334686e-04, -1.51634216e-04, 1.72615051e-04, 4.27246094e-04,
           3.94821167e-04, -2.20298767e-04, -2.51770020e-04, -4.98294830e-05,
           1.90734863e-04, 2.39372253e-04, -6.29425049e-04, 5.91278076e-04,
           -7.39097595e-06, -3.37600708e-04, -3.94821167e-04, 3.64303589e-04,
           -9.53674316e-05, 2.86102295e-04, 5.79833984e-04, 5.98907471e-04]]]),
    "output_query_group_8": np.array(
        [[[-2.2315979e-04, 4.4822693e-04, 5.9127808e-04, 3.2806396e-04,
           -1.9168854e-04, -7.2097778e-04, -2.4223328e-04, -5.0735474e-04,
           3.6048889e-04, -3.4523010e-04, -2.8610229e-04, -3.8719177e-04,
           8.2397461e-04, 3.1280518e-04, -1.0910034e-03, -1.4209747e-04,
           2.2315979e-04, 8.5353851e-05, -1.0147095e-03, -1.9931793e-04,
           -1.0538101e-04, 1.1596680e-03, 2.2792816e-04, -7.1716309e-04,
           3.2043457e-04, 4.2724609e-04, 1.4305115e-04, -3.2997131e-04,
           -1.7395020e-03, -3.6430359e-04, -4.2152405e-04, 9.3841553e-04],
          [-2.7275085e-04, 4.2152405e-04, 5.7220459e-04, 2.9373169e-04,
           -1.2588501e-04, -6.8283081e-04, -2.3078918e-04, -4.3487549e-04,
           2.3746490e-04, -3.1852722e-04, -2.6893616e-04, -4.2724609e-04,
           7.8582764e-04, 3.4523010e-04, -1.0223389e-03, -1.4400482e-04,
           2.3746490e-04, 1.8024445e-04, -9.9182129e-04, -1.9168854e-04,
           -1.8596649e-04, 1.0910034e-03, 2.1076202e-04, -6.4849854e-04,
           2.4223328e-04, 4.5394897e-04, 1.5354156e-04, -3.6048889e-04,
           -1.6021729e-03, -2.8419495e-04, -3.2234192e-04, 9.6511841e-04]],

         [[-2.4032593e-04, 4.5776367e-04, 5.5313110e-04, 3.3187866e-04,
           -1.8787384e-04, -7.3623657e-04, -2.4986267e-04, -5.2261353e-04,
           3.3950806e-04, -3.4523010e-04, -2.8991699e-04, -3.9672852e-04,
           8.3541870e-04, 3.0136108e-04, -1.0986328e-03, -1.6021729e-04,
           2.1839142e-04, 1.1396408e-04, -1.0147095e-03, -1.9836426e-04,
           -1.3828278e-04, 1.1444092e-03, 2.1266937e-04, -7.2097778e-04,
           3.0326843e-04, 4.3869019e-04, 1.2683868e-04, -3.4713745e-04,
           -1.7471313e-03, -3.2424927e-04, -4.1198730e-04, 9.5367432e-04],
          [-2.4795532e-04, 3.8337708e-04, 5.6838989e-04, 2.7275085e-04,
           -2.0408630e-04, -6.7901611e-04, -1.9645691e-04, -4.2343140e-04,
           2.8038025e-04, -3.3950806e-04, -2.8228760e-04, -3.9672852e-04,
           7.8201294e-04, 3.8146973e-04, -1.0528564e-03, -1.2969971e-04,
           2.1266937e-04, 1.4591217e-04, -9.8419189e-04, -1.9741058e-04,
           -1.5449524e-04, 1.0986328e-03, 1.9645691e-04, -6.4849854e-04,
           2.8800964e-04, 4.6348572e-04, 1.4400482e-04, -3.5476685e-04,
           -1.6021729e-03, -3.1280518e-04, -3.5095215e-04, 9.8419189e-04]]]),
    "bias_query_group_4": np.array(
        [0.00665283, 0.01940918, -0.00878906, -0.00378418, 0.00231934,
         0.00646973, -0.00215149, -0.00872803, 0.00878906, 0.00720215,
         -0.00915527, 0.0135498, 0.01171875, 0.00134277, 0.00079727,
         0.0055542, -0.00860596, 0.00029945, -0.02148438, 0.00878906,
         -0.015625, 0.01501465, -0.0032959, -0.00212097, -0.00628662,
         -0.00288391, 0.01416016, -0.02490234, 0.01275635, 0.00338745,
         -0.01208496, -0.01074219]),
    "bias_query_group_8": np.array(
        [-0.00019836, -0.00107574, 0.01220703, -0.01074219, -0.00762939,
         0.01037598, -0.00595093, 0.0090332, 0.01696777, -0.00402832,
         0.00680542, -0.01153564, 0.00753784, 0.0019989, 0.00473022,
         -0.00646973, -0.01660156, 0.0050354, -0.02062988, -0.00317383,
         0.00939941, 0.00192261, 0.01953125, 0.01025391, -0.00653076,
         0.00027275, -0.00047302, -0.01586914, -0.00939941, -0.00561523,
         -0.02038574, 0.00047493]),
}
